{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "# GAN overriding `Model.train_step`\n",
    "\n",
    "**Author:** [fchollet](https://twitter.com/fchollet)<br>\n",
    "**Date created:** 2019/04/29<br>\n",
    "**Last modified:** 2020/04/29<br>\n",
    "**Description:** A simple DCGAN trained using `fit()` by overriding `train_step`."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "## Setup\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab_type": "code"
   },
   "outputs": [],
   "source": [
    "import tensorflow as tf\n",
    "from tensorflow import keras\n",
    "from tensorflow.keras import layers\n",
    "import numpy as np\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "## Prepare MNIST data\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab_type": "code"
   },
   "outputs": [],
   "source": [
    "# The labels are discarded and the training and test digits are pooled\n",
    "# into a single array.\n",
    "batch_size = 64\n",
    "(x_train, _), (x_test, _) = keras.datasets.mnist.load_data()\n",
    "# Pixel values are cast to float32 and divided by 255, then a trailing\n",
    "# channel axis is added so every sample has shape (28, 28, 1).\n",
    "all_digits = np.reshape(\n",
    "    np.concatenate([x_train, x_test]).astype(\"float32\") / 255,\n",
    "    (-1, 28, 28, 1),\n",
    ")\n",
    "dataset = (\n",
    "    tf.data.Dataset.from_tensor_slices(all_digits)\n",
    "    .shuffle(buffer_size=1024)\n",
    "    .batch(batch_size)\n",
    "    .prefetch(32)\n",
    ")\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "## Create the discriminator\n",
    "\n",
    "It maps 28x28 digits to a binary classification score.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab_type": "code"
   },
   "outputs": [],
   "source": [
    "# Two stride-2 convolutions shrink the 28x28 input to 14x14 and then 7x7.\n",
    "# Global max pooling keeps one value per feature map, and the final Dense\n",
    "# layer has no activation, so the model outputs a raw score (a logit).\n",
    "discriminator = keras.Sequential(\n",
    "    [\n",
    "        keras.Input(shape=(28, 28, 1)),\n",
    "        layers.Conv2D(filters=64, kernel_size=3, strides=2, padding=\"same\"),\n",
    "        layers.LeakyReLU(alpha=0.2),\n",
    "        layers.Conv2D(filters=128, kernel_size=3, strides=2, padding=\"same\"),\n",
    "        layers.LeakyReLU(alpha=0.2),\n",
    "        layers.GlobalMaxPooling2D(),\n",
    "        layers.Dense(units=1),\n",
    "    ],\n",
    "    name=\"discriminator\",\n",
    ")\n",
    "\n",
    "discriminator.summary()\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "## Create the generator\n",
    "\n",
    "It mirrors the discriminator, replacing `Conv2D` layers with `Conv2DTranspose` layers.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab_type": "code"
   },
   "outputs": [],
   "source": [
    "latent_dim = 128\n",
    "\n",
    "generator = keras.Sequential(\n",
    "    [\n",
    "        keras.Input(shape=(latent_dim,)),\n",
    "        # Project the latent vector to 7 * 7 * 128 = 6272 values so that they\n",
    "        # can be reshaped into a 7x7 feature map with 128 channels.\n",
    "        layers.Dense(units=7 * 7 * 128),\n",
    "        layers.LeakyReLU(alpha=0.2),\n",
    "        layers.Reshape(target_shape=(7, 7, 128)),\n",
    "        # Each stride-2 transposed convolution doubles the height and width:\n",
    "        # 7x7 -> 14x14 -> 28x28.\n",
    "        layers.Conv2DTranspose(filters=128, kernel_size=4, strides=2, padding=\"same\"),\n",
    "        layers.LeakyReLU(alpha=0.2),\n",
    "        layers.Conv2DTranspose(filters=128, kernel_size=4, strides=2, padding=\"same\"),\n",
    "        layers.LeakyReLU(alpha=0.2),\n",
    "        # A single sigmoid-activated channel gives pixel values in [0, 1].\n",
    "        layers.Conv2D(filters=1, kernel_size=7, padding=\"same\", activation=\"sigmoid\"),\n",
    "    ],\n",
    "    name=\"generator\",\n",
    ")\n",
    "\n",
    "generator.summary()\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "## Override `train_step`\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab_type": "code"
   },
   "outputs": [],
   "source": [
    "\n",
    "class GAN(keras.Model):\n",
    "    \"\"\"Couples a discriminator and a generator so both train through `fit()`.\n",
    "\n",
    "    Label convention in `train_step`: generated images are labelled 1 and\n",
    "    real images are labelled 0.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, discriminator, generator, latent_dim):\n",
    "        super().__init__()\n",
    "        self.discriminator = discriminator\n",
    "        self.generator = generator\n",
    "        self.latent_dim = latent_dim\n",
    "\n",
    "    def compile(self, d_optimizer, g_optimizer, loss_fn):\n",
    "        \"\"\"Store one optimizer per sub-model plus the loss used by both.\"\"\"\n",
    "        super().compile()\n",
    "        self.d_optimizer = d_optimizer\n",
    "        self.g_optimizer = g_optimizer\n",
    "        self.loss_fn = loss_fn\n",
    "\n",
    "    def train_step(self, real_images):\n",
    "        \"\"\"Update the discriminator once, then the generator once.\n",
    "\n",
    "        Returns a dict holding the `d_loss` and `g_loss` of this step.\n",
    "        \"\"\"\n",
    "        # Only the first element is kept when the batch arrives as a tuple.\n",
    "        if isinstance(real_images, tuple):\n",
    "            real_images = real_images[0]\n",
    "        n_samples = tf.shape(real_images)[0]\n",
    "\n",
    "        # ---- Discriminator update ----\n",
    "        # Generate one fake image per real image in the batch.\n",
    "        noise = tf.random.normal(shape=(n_samples, self.latent_dim))\n",
    "        fake_images = self.generator(noise)\n",
    "        mixed_images = tf.concat([fake_images, real_images], axis=0)\n",
    "\n",
    "        # Targets follow the concat order: ones for fakes, zeros for reals.\n",
    "        targets = tf.concat(\n",
    "            [tf.ones((n_samples, 1)), tf.zeros((n_samples, 1))], axis=0\n",
    "        )\n",
    "        # Add random noise to the targets - important trick!\n",
    "        targets += 0.05 * tf.random.uniform(tf.shape(targets))\n",
    "\n",
    "        d_weights = self.discriminator.trainable_weights\n",
    "        with tf.GradientTape() as d_tape:\n",
    "            d_loss = self.loss_fn(targets, self.discriminator(mixed_images))\n",
    "        d_grads = d_tape.gradient(d_loss, d_weights)\n",
    "        self.d_optimizer.apply_gradients(zip(d_grads, d_weights))\n",
    "\n",
    "        # ---- Generator update ----\n",
    "        noise = tf.random.normal(shape=(n_samples, self.latent_dim))\n",
    "        # Zero is the \"real\" label, so these targets say \"all real images\".\n",
    "        fooling_targets = tf.zeros((n_samples, 1))\n",
    "\n",
    "        # Gradients are taken with respect to the generator weights only, so\n",
    "        # the discriminator is *not* updated here.\n",
    "        g_weights = self.generator.trainable_weights\n",
    "        with tf.GradientTape() as g_tape:\n",
    "            g_loss = self.loss_fn(\n",
    "                fooling_targets, self.discriminator(self.generator(noise))\n",
    "            )\n",
    "        g_grads = g_tape.gradient(g_loss, g_weights)\n",
    "        self.g_optimizer.apply_gradients(zip(g_grads, g_weights))\n",
    "\n",
    "        return {\"d_loss\": d_loss, \"g_loss\": g_loss}\n",
    "\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "## Create a callback that periodically saves generated images\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab_type": "code"
   },
   "outputs": [],
   "source": [
    "\n",
    "class GANMonitor(keras.callbacks.Callback):\n",
    "    \"\"\"Callback that saves a few generated images as PNG files after each epoch.\n",
    "\n",
    "    Args:\n",
    "        num_img: number of images to generate and save per epoch.\n",
    "        latent_dim: size of the latent vectors fed to the generator.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, num_img=3, latent_dim=128):\n",
    "        # Initialise the base `Callback` state before adding our own attributes.\n",
    "        super().__init__()\n",
    "        self.num_img = num_img\n",
    "        self.latent_dim = latent_dim\n",
    "\n",
    "    def on_epoch_end(self, epoch, logs=None):\n",
    "        \"\"\"Generate `num_img` images and write them to the working directory.\n",
    "\n",
    "        Files are named `generated_img_{i}_{epoch}.png`, where `epoch` is the\n",
    "        index passed in by Keras.\n",
    "        \"\"\"\n",
    "        random_latent_vectors = tf.random.normal(shape=(self.num_img, self.latent_dim))\n",
    "        generated_images = self.model.generator(random_latent_vectors)\n",
    "        # Keep the converted array: `.numpy()` returns a new object and does\n",
    "        # not modify the tensor it is called on.\n",
    "        generated_images = (generated_images * 255).numpy()\n",
    "        for i in range(self.num_img):\n",
    "            img = keras.preprocessing.image.array_to_img(generated_images[i])\n",
    "            img.save(\"generated_img_{i}_{epoch}.png\".format(i=i, epoch=epoch))\n",
    "\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "## Train the end-to-end model\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab_type": "code"
   },
   "outputs": [],
   "source": [
    "epochs = 30\n",
    "\n",
    "gan = GAN(discriminator, generator, latent_dim)\n",
    "# Both sub-models use Adam with the same learning rate. The loss is created\n",
    "# with `from_logits=True`, matching the discriminator's un-activated output.\n",
    "gan.compile(\n",
    "    d_optimizer=keras.optimizers.Adam(learning_rate=3e-4),\n",
    "    g_optimizer=keras.optimizers.Adam(learning_rate=3e-4),\n",
    "    loss_fn=keras.losses.BinaryCrossentropy(from_logits=True),\n",
    ")\n",
    "\n",
    "# Save three generated samples to disk at the end of every epoch.\n",
    "image_monitor = GANMonitor(num_img=3, latent_dim=latent_dim)\n",
    "gan.fit(dataset, epochs=epochs, callbacks=[image_monitor])\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "Display the last generated images:\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab_type": "code"
   },
   "outputs": [],
   "source": [
    "from IPython.display import Image, display\n",
    "\n",
    "# `GANMonitor` names its files with the 0-based epoch index, so the images\n",
    "# from the final epoch carry the suffix `epochs - 1`. Deriving it from\n",
    "# `epochs` keeps this cell working when the number of epochs is changed.\n",
    "last_epoch = epochs - 1\n",
    "for i in range(3):  # 3 matches `num_img=3` passed to `GANMonitor` above\n",
    "    display(Image(\"generated_img_{i}_{epoch}.png\".format(i=i, epoch=last_epoch)))\n"
   ]
  }
 ],
 "metadata": {
  "colab": {
   "collapsed_sections": [],
   "name": "dcgan_overriding_train_step",
   "private_outputs": false,
   "provenance": [],
   "toc_visible": true
  },
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}